{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "SAC的简化版实现,alpha使用常量代替.只使用一个value模型,而不是两个."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAR8AAAEXCAYAAACUBEAgAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAcP0lEQVR4nO3de3BU9d0/8Pdu9pLL5uySQHaJJJKnoJDhUg0YVmdqKynRplZq+oxlGE0tPx3pwoB0mJpW8KnTmfDDaa1WhM44Ff7BdLAFKwU1T9CgwwoYSQ231P4GmzzAbriY3RDIbnb38/tDch5WLu5Cdr/Z9f2aOTPknO8m70PYN2fP1SAiAiKiNDOqDkBEX08sHyJSguVDREqwfIhICZYPESnB8iEiJVg+RKQEy4eIlGD5EJESLB8iUkJZ+axbtw4TJ05Ebm4uqqursW/fPlVRiEgBJeXz5z//GStWrMAzzzyDjz/+GDNnzkRtbS16e3tVxCEiBQwqLiytrq7G7Nmz8dJLLwEAYrEYysrKsHTpUjz11FNf+fpYLIYTJ06gsLAQBoMh1XGJKEEigv7+fpSWlsJovPa2jSlNmXThcBjt7e1obGzU5xmNRtTU1MDr9V7xNaFQCKFQSP/6+PHjqKysTHlWIro+PT09mDBhwjXHpL18Tp8+jWg0CqfTGTff6XTi6NGjV3xNU1MTfv3rX182v6enB5qmpSQnESUvGAyirKwMhYWFXzk27eVzPRobG7FixQr96+EV1DSN5UM0CiWyOyTt5TN27Fjk5OTA7/fHzff7/XC5XFd8jdVqhdVqTUc8IkqTtB/tslgsqKqqQmtrqz4vFouhtbUVbrc73XGISBElH7tWrFiBhoYGzJo1C3fccQd+//vfY2BgAI8++qiKOESkgJLyeeihh3Dq1CmsXr0aPp8P3/zmN/HWW29dthOaiLKXkvN8blQwGITdbkcgEOAOZ6JRJJn3Jq/tIiIlWD5EpATLh4iUYPkQkRIsHyJSguVDREqwfIhICZYPESnB8iEiJVg+RKQEy4eIlGD5EJESLB8iUoLlQ0RKsHyISAmWDxEpwfIhIiVYPkSkBMuHiJRg+RCREiwfIlKC5UNESrB8iEgJlg8RKcHyISIlWD5EpATLh4iUYPkQkRIsHyJSguVDREqwfIhICZYPESnB8iEiJVg+RKQEy4eIlEi6fHbv3o37778fpaWlMBgM2LZtW9xyEcHq1asxfvx45OXloaamBp9++mncmLNnz2LhwoXQNA0OhwOLFi3CuXPnbmhFiCizJF0+AwMDmDlzJtatW3fF5WvXrsWLL76IDRs2YO/evSgoKEBtbS0GBwf1MQsXLsShQ4fQ0tKC7du3Y/fu3Xj88cevfy2IKPPIDQAgW7du1b+OxWLicrnkueee0+f19fWJ1WqV1157TUREDh8+LABk//79+pidO3eKwWCQ48ePJ/RzA4GAAJBAIHAj8YlohCXz3hzRfT7Hjh2Dz+dDTU2NPs9ut6O6uhperxcA4PV64XA4MGvWLH1MTU0NjEYj9u7de8XvGwqFEAwG4yYiymwjWj4+nw8A4HQ64+Y7nU59mc/nQ0lJSdxyk8mEoqIifcyXNTU1wW6361NZWdlIxiYiBTLiaFdjYyMCgYA+9fT0qI5ERDdoRMvH5XIBAPx+f9x8v9+vL3O5XOjt7Y1bHolEcPbsWX3Ml1mtVmiaFjcRUWYb0fKpqKiAy+VCa2urPi8YDGLv3r1wu90AALfbjb6+PrS3t+tjdu3ahVgshurq6pGMQ0SjmCnZF5w7dw7/+te/9K+PHTuGjo4OFBUVoby8HMuXL8dvfvMbTJ48GRUVFVi1ahVKS0sxf/58AMDUqVNx77334rHHHsOGDRswNDSEJUuW4Mc//jFKS0tHbMWIaJRL9lDau+++KwAumxoaGkTki8Ptq1atEqfTKVarVebOnStdXV1x3+PMmTOyYMECsdlsommaPProo9Lf359wBh5qJxqdknlvGkREFHbfdQkGg7Db7QgEAtz/QzSKJPPezIijXUSUfVg+RKQEy4eIlEj6aBdRKl1tF6TBYEhzEko1lg+NGiKCodOncXb3bvQfPIjo+fPILS2F4847oc2cCaPFojoijSCWD40KIoJzBw+ie8MGDP7P/wAXt4AGjhzB2fffR/E99+CmRx6ByWZTnJRGCvf50Khw4bPP8O+XXsJgT49ePMMkHMbpd97ByddeQ2xoSFFCGmksH1IuNjQE31/+gtDJk9cYFMOpt9/GwNGj6QtGKcXyIaVEBJ+//z76Lt7v6Zpjw2F8vmdPGlJROrB8SK1YDOe6uiCJfpyKxVKbh9KG5UNKSSSC0PHjqmOQAiwfUioWieBCd3fiLzDyn2y24G+S1ErmumajEQW33pq6LJRWLB9SKjY4mPB+HIPBAMuX7v9NmYvlQ0pFgkFIJJLYYIMBpsLC1AaitGH5kFIXPvsM0UseKPlVeI1X9mD5kDIiguj58wl/7MopKICB13dlDZYPKZXw+T0AcsvLYeKdK7MGy4fUEcFQX1/Cw00FBbyyPYuwfEgZicUwcMmTUL6KwWSCgef5ZA3+JkmdWAxDp08nPNyYm5vCMJRuLB9SJ8kHpxROm5aiIKQCy4eUiV64kNT9eczFxSlMQ+nG8iFlQj4fogMDiQ02GJCTn5/aQJRWLB9SJtzbi9iFC4kNNhi+mChrsHxIiWQflGt1OmHhx66swvIhZaLnzyc8NsdmgzEvL4VpKN1YPqTM0NmzCY/NycuD0WpNYRpKN5YPqSGCc4cOJT6e+3uyDsuHlEnmY5e5qCiFSUgFlg8pIbEYJImbwdumTk1hGlKB5UNKRAcGEAuFEh7PLZ/sw/IhJcI+X1I7nI1mM28klmVYPqRE5Pz5L+7fnACDycSbiGUhlg+lnYgk9fA/c1ERrC5XChORCkmVT1NTE2bPno3CwkKUlJRg/vz56OrqihszODgIj8eD4uJi2Gw21NfXw+/3x43p7u5GXV0d8vPzUVJSgpUrVyKS6E3EKStE+vsTHmvMzUVOQUEK05AKSZVPW1sbPB4PPvzwQ7S0tGBoaAjz5s3DwCUXBz755JN48803sWXLFrS1teHEiRN48MEH9eXRaBR1dXUIh8PYs2cPNm3ahI0bN2L16tUjt1Y06oV7exMeazSbeQfDLGSQZC+yucSpU6dQUlKCtrY2fOtb30IgEMC4ceOwefNm/OhHPwIAHD16FFOnToXX68WcOXOwc+dOfP/738eJEyfgdDoBABs2bMAvfvELnDp1CpYE/pEFg0HY7XYEAgFovKdvxhER/Ou//gvBAwcSGp//jW9gym9/y7sYZoBk3ps39NsMBAIAgKKLh0Hb29sxNDSEmpoafcyUKVNQXl4Or9cLAPB6vZg+fbpePABQW1uLYDCIQ1c54zUUCiEYDMZNlNkkGk14bG55Oc9wzkLXXT6xWAzLly/HXXfdhWkX7zDn8/lgsVjgcDjixjqdTvh8Pn3MpcUzvHx42ZU0NTXBbrfrU1lZ2fXGplFAotGkbiKW/x//kcI0pMp1l4/H48HBgwfR3Nw8knmuqLGxEYFAQJ96enpS/jMpdWIXLiR1aYXJbk9hGlLFdD0vWrJkCbZv347du3djwoQJ+nyXy4VwOIy+vr64rR+/3w/XxUOlLpcL+/bti/t+w0fDXFc5nGq1WmHlFc1ZI3zqFEInTiQ83mA08gTDLJTUlo+IYMmSJdi6dSt27dqFioqKuOVVVVUwm81obW3V53V1daG7uxtutxsA4Ha70dnZid5Ljna0tLRA0zRUVlbeyLpQhpBIJOHnsxssFj4oMEslteXj8XiwefNmvPHGGygsLNT30djtduTl5cFut2PRokVYsWIFioqKoGkali5dCrfbjTlz5gAA5s2bh8rKSjz88MNYu3YtfD4fnn76aXg8Hm7dfE3EIpGEn1xhstlgLS1NcSJSIanyWb9+PQDg29/+dtz8V199FT/5yU8AAM8//zyMRiPq6+sRCoVQW1uLl19+WR+bk5OD7du3Y/HixXC73SgoKEBDQwOeffbZG1sTyhiRzz9PeKzBYoHJZkthGlIlqfJJ5JSg3NxcrFu3DuvWrbvqmJtvvhk7duxI5kdTFhk8eTLhsQajEQbTde2apFGOZ21R2p1P4hHJlL1YPpRWyZ5Qb5s6FYacnBSlIZVYPpRWEokk/qwu4Iur2XlZRVbib5XSKjY4iKGLl+UkglezZy+WD6VVJBjEYDJnqBsMPMEwS7F8KL1EEj7Hx2i1wjp+fIoDkSosH0qraCiUcPkYLBaWTxZj+VBaDZ05k/ARL4PJxBMMsxjLh9Kqv7Mz4fs3GwAe6cpi/M1SWiX1rK6xY3mOTxZj+VDaJPvUirzychh5aUXWYvlQ2kgkktRTK0wOB8Atn6zF/1YobWKhUNxTK0QEfeEw/l9/P+wWC75RWAjjJef0GM1mFTEpTVg+lDaxUAiDF+9gKCLoHhjAqgMH0BUIoMBkwv+55RY8VFGBnOEC4h0Msxo/dlH6GAz67TEEwP/t7MThvj5ERRAcGsJLR47g4MV7/RjMZhTccovCsJRqLB9KG7PDAcfs2frXwS89wSIciyF08ZE6hpwcWMaNS2s+Si+WD6WNwWjUn0RhAPAdlwumSz5W3aJpuPniSYXGvDyeYJjluM+H0qr4O9/B6ZYWRM+dQ8OkSSg0m/HfJ09ifF4eHrvlFpTk5gIAiu6+m4/MyXIsH0qr3PJyjF+wAMc3bYIpHMZ/TpyIH02ciOHtH4PBAFtlJZzz5/PxyFmO5UNpZTAaMe7eewEA/r/8BUOffw7DxWu9DGYztNtuw4RFi2C5+Ahuyl4sH0o7o9mMkro62G+/HcEDBxDy+WDMy4Nt6lTYpk5FTl6e6oiUBiwfUsJgNCL3ppuQe9NNqqOQIvxQTURKsHyISAmWDxEpwfIhIiVYPkSkBMuHiJRg+RCREiwfIlKC5UNESrB8iEgJlg8RKcHyISIlWD5EpATLh4iUSKp81q9fjxkzZkDTNGiaBrfbjZ07d+rLBwcH4fF4UFxcDJvNhvr6evj9/rjv0d3djbq6OuTn56OkpAQrV65EJBIZmbUhooyRVPlMmDABa9asQXt7Oz766CPcc889eOCBB3Do0CEAwJNPPok333wTW7ZsQVtbG06cOIEHH3xQf300GkVdXR3C4TD27NmDTZs2YePGjVi9evXIrhURjX5yg8aMGSOvvPKK9PX1idlsli1btujLjhw5IgDE6/WKiMiOHTvEaDSKz+fTx6xfv140TZNQKHTVnzE4OCiBQECfenp6BIAEAoEbjU9EIygQCCT83rzufT7RaBTNzc0YGBiA2+1Ge3s7hoaGUFNTo4+ZMmUKysvL4fV6AQBerxfTp0+H0+nUx9TW1iIYDOpbT1fS1NQEu92uT2VlZdcbm4hGiaTLp7OzEzabDVarFU888QS2bt2KyspK+Hw+WCwWOByOuPFOpxM+nw8A4PP54opnePnwsqtpbGxEIBDQp56enmRjE9Eok/Q9nG+99VZ0dHQgEAjg9ddfR0NDA9ra2lKRTWe1WmG1WlP6M4govZIuH4vFgkmTJgEAqqqqsH//frzwwgt46KGHEA6H0dfXF7f14/f74XK5AAAulwv79u2L+37DR8OGxxDR18MNn+cTi8UQCoVQVVUFs9mM1tZWfVlXVxe6u7vhdrsBAG63G52dnejt7dXHtLS0QNM0VFZW3mgUIsogSW35NDY24r777kN5eTn6+/uxefNmvPfee3j77bdht9uxaNEirFixAkVFRdA0DUuXLoXb7cacOXMAAPPmzUNlZSUefvhhrF27Fj6fD08//TQ8Hg8/VhF9zSRVPr29vXjkkUdw8uRJ2O12zJgxA2+//Ta++93vAgCef/55GI1G1NfXIxQKoba2Fi+//LL++pycHGzfvh2LFy+G2+1GQUEBGhoa8Oyzz47sWhHRqGcQufis2gwSDAZht9sRCASgaZrqOER0UTLvTV7bRURKsHyISAmWDxEpwfIhIiVYPkSkBMuHiJRg+RCREiwfIlKC5UNESrB8iEgJlg8RKcHyISIlWD5EpATLh4iUYPkQkRIsHyJSguVDREqwfIhICZYPESnB8iEiJVg+RKQEy4eIlGD5EJESLB8iUoLlQ0RKsHyISAmWDxEpwfIhIiVYPkSkBMuHiJRg+RCREiwfIlKC5UNESrB8iEgJlg8RKXFD5bNmzRoYDAYsX75cnzc4OAiPx4Pi4mLYbDbU19fD7/fHva67uxt1dXXIz89HSUkJVq5ciUgkciNRiCjDXHf57N+/H3/84x8xY8aMuPlPPvkk3nzzTWzZsgVtbW04ceIEHnzwQX15NBpFXV0dwuEw9uzZg02bNmHjxo1YvXr19a8FEWUeuQ79/f0yefJkaWlpkbvvvluWLVsmIiJ9fX1iNptly5Yt+tgjR44IAPF6vSIismPHDjEajeLz+fQx69evF03TJBQKXfHnDQ4OSiAQ0Keenh4BIIFA4HriE1GKBAKBhN+b17Xl4/F4UFdXh5qamrj57e3tGBoaips/ZcoUlJeXw+v1AgC8Xi+mT58Op9Opj6mtrUUwGMShQ4eu+POamppgt9v1qays7HpiE9EoknT5NDc34+OPP0ZTU9Nly3w+HywWCxwOR9x8p9MJn8+nj7m0eIaXDy+7ksbGRgQCAX3q6elJNjYRjTKmZAb39PRg2bJlaGlpQW5ubqoyXcZqtcJqtabt5xFR6iW15dPe3o7e3l7cfvvtMJlMMJlMaGtrw4svvgiTyQSn04lwOIy+vr641/n9frhcLgCAy+W67OjX8NfDY4go+yVVPnPnzkVnZyc6Ojr0adasWVi4cKH+Z7PZjNbWVv01XV1d6O7uhtvtBgC43W50dnait7dXH9PS0gJN01BZWTlCq0VEo11SH7sKCwsxbdq0uHkFBQUoLi7W5y9atAgrVqxAUVERNE3D0qVL4Xa7MWfOHADAvHnzUFlZiYcffhhr166Fz+fD008/DY/Hw49WRF8jSZVPIp5//nkYjUbU19cjFAqhtrYWL7/8sr48JycH27dvx+LFi+F2u1FQUICGhgY8++yzIx2FiEYxg4iI6hDJCgaDsNvtCAQC0DRNdRwiuiiZ9yav7SIiJVg+RKQEy4eIlGD5EJESLB8iUoLlQ0RKsHyISAmWDxEpwfIhIiVYPkSkBMuHiJRg+RCREiwfIlKC5UNESrB8iEgJlg8RKcHyISIlWD5EpATLh4iUYPkQkRIsHyJSguVDREqwfIhICZYPESnB8iEiJVg+RKQEy4eIlGD5EJESLB8iUoLlQ0RKsHyISAmWDxEpwfIhIiVYPkSkBMuHiJRg+RCREiwfIlLCpDrA9RARAEAwGFSchIguNfyeHH6PXktGls+ZM2cAAGVlZYqTENGV9Pf3w263X3NMRpZPUVERAKC7u/srV3C0CQaDKCsrQ09PDzRNUx0nYcydXpmaW0TQ39+P0tLSrxybkeVjNH6xq8put2fUL+ZSmqZlZHbmTq9MzJ3oBgF3OBOREiwfIlIiI8vHarXimWeegdVqVR0laZmanbnTK1NzJ8MgiRwTIyIaYRm55UNEmY/lQ0RKsHyISAmWDxEpwfIhIiUysnzWrVuHiRMnIjc3F9XV1di3b5/SPLt378b999+P0tJSGAwGbNu2LW65iGD16tUYP3488vLyUFNTg08//TRuzNmzZ7Fw4UJomgaHw4FFixbh3LlzKc3d1NSE2bNno7CwECUlJZg/fz66urrixgwODsLj8aC4uBg2mw319fXw+/1xY7q7u1FXV4f8/HyUlJRg5cqViEQiKcu9fv16zJgxQz/71+12Y+fOnaM685WsWbMGBoMBy5cvz7jsI0IyTHNzs1gsFvnTn/4khw4dkscee0wcDof4/X5lmXbs2CG/+tWv5K9//asAkK1bt8YtX7Nmjdjtdtm2bZv84x//kB/84AdSUVEhFy5c0Mfce++9MnPmTPnwww/l/fffl0mTJsmCBQtSmru2tlZeffVVOXjwoHR0dMj3vvc9KS8vl3PnzuljnnjiCSkrK5PW1lb56KOPZM6cOXLnnXfqyyORiEybNk1qamrkwIEDsmPHDhk7dqw0NjamLPff/vY3+fvf/y7//Oc/paurS375y1+K2WyWgwcPjtrMX7Zv3z6ZOHGizJgxQ5YtW6bPz4TsIyXjyueOO+4Qj8ejfx2NRqW0tFSampoUpvpfXy6fWCwmLpdLnnvuOX1eX1+fWK1Wee2110RE5PDhwwJA9u/fr4/ZuXOnGAwGOX78eNqy9/b2CgBpa2vTc5rNZtmyZYs+5siRIwJAvF6viHxRvEajUXw+nz5m/fr1ommahEKhtGUfM2aMvPLKKxmRub+/XyZPniwtLS1y99136+WTCdlHUkZ97AqHw2hvb0dNTY0+z2g0oqamBl6vV2Gyqzt27Bh8Pl9cZrvdjurqaj2z1+uFw+HArFmz9DE1NTUwGo3Yu3dv2rIGAgEA/3vXgPb2dgwNDcVlnzJlCsrLy+OyT58+HU6nUx9TW1uLYDCIQ4cOpTxzNBpFc3MzBgYG4Ha7MyKzx+NBXV1dXEYgM/6+R1JGXdV++vRpRKPRuL94AHA6nTh69KiiVNfm8/kA4IqZh5f5fD6UlJTELTeZTCgqKtLHpFosFsPy5ctx1113Ydq0aXoui8UCh8NxzexXWrfhZanS2dkJt9uNwcFB2Gw2bN26FZWVlejo6Bi1mQGgubkZH3/8Mfbv33/ZstH8950KGVU+lDoejwcHDx7EBx98oDpKQm699VZ0dHQgEAjg9ddfR0NDA9ra2lTHuqaenh4sW7YMLS0tyM3NVR1HuYz62DV27Fjk5ORctvff7/fD5XIpSnVtw7muldnlcqG3tzdueSQSwdmzZ9OyXkuWLMH27dvx7rvvYsKECfp8l8uFcDiMvr6+a2a/0roNL0sVi8WCSZMmoaqqCk1NTZg5cyZeeOGFUZ25vb0dvb29uP3222EymWAymdDW1oYXX3wRJpMJTqdz1GZPhYwqH4vFgqqqKrS2turzYrEYWltb4Xa7FSa7uoqKCrhcrrjMwWAQe/fu1TO73W709fWhvb1dH7Nr1y7EYjFUV1enLJuIYMmSJdi6dSt27dqFioqKuOVVVVUwm81x2bu6utDd3R2XvbOzM648W1paoGkaKisrU5b9y2KxGEKh0KjOPHfuXHR2dqKjo0OfZs2ahYULF+p/Hq3ZU0L1Hu9kNTc3i9VqlY0bN8rhw4fl8ccfF4fDEbf3P936+/vlwIEDcuDAAQEgv/vd7+TAgQPy73//W0S+ONTucDjkjTfekE8++UQeeOCBKx5qv+2222Tv3r3ywQcfyOTJk1N+qH3x4sVit9vlvffek5MnT+rT+fPn9TFPPPGElJeXy65du+Sjjz4St9stbrdbXz586HfevHnS0dEhb731lowbNy6lh36feuopaWtrk2PHjsknn3wiTz31lBgMBnnnnXdGbearufRoV6Zlv1EZVz4iIn/4wx+kvLxcLBaL3HHHHfLhhx8qzfPuu+8KgMumhoYGEfnicPuqVavE6XSK1WqVuXPnSldXV9z3OHPmjCxYsEBsNptomiaPPvqo9Pf3pzT3lTIDkFdffVUfc+HCBfnZz34mY8aMkfz8fPnhD38oJ0+ejPs+n332mdx3332Sl5cnY8eOlZ///OcyNDSUstw//elP5eabbxaLxSLjxo2TuXPn6sUzWjNfzZfLJ5Oy3yjez4eIlMiofT5ElD1YPkSkBMuHiJRg+RCREiwfIlKC5UNESrB8iEgJlg8RKcHyISIlWD5EpATLh4iU+P963GJuLiVjrwAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 300x300 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import gym\n",
    "\n",
    "\n",
    "#定义环境\n",
    "class MyWrapper(gym.Wrapper):\n",
    "\n",
    "    def __init__(self):\n",
    "        env = gym.make('Pendulum-v1', render_mode='rgb_array')\n",
    "        super().__init__(env)\n",
    "        self.env = env\n",
    "        self.step_n = 0\n",
    "\n",
    "    def reset(self):\n",
    "        state, _ = self.env.reset()\n",
    "        self.step_n = 0\n",
    "        return state\n",
    "\n",
    "    def step(self, action):\n",
    "        state, reward, terminated, truncated, info = self.env.step(\n",
    "            [action * 2])\n",
    "        over = terminated or truncated\n",
    "\n",
    "        #偏移reward,便于训练\n",
    "        reward = (reward + 8) / 8\n",
    "\n",
    "        #限制最大步数\n",
    "        self.step_n += 1\n",
    "        if self.step_n >= 200:\n",
    "            over = True\n",
    "\n",
    "        return state, reward, over\n",
    "\n",
    "    #打印游戏图像\n",
    "    def show(self):\n",
    "        from matplotlib import pyplot as plt\n",
    "        plt.figure(figsize=(3, 3))\n",
    "        plt.imshow(self.env.render())\n",
    "        plt.show()\n",
    "\n",
    "\n",
    "env = MyWrapper()\n",
    "\n",
    "env.reset()\n",
    "\n",
    "env.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[-0.1095],\n",
       "         [-0.0109]], grad_fn=<TanhBackward0>),\n",
       " tensor([[1.1577],\n",
       "         [0.9720]], grad_fn=<ExpBackward0>))"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "class ModelAction(torch.nn.Module):\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.s = torch.nn.Sequential(\n",
    "            torch.nn.Linear(3, 64),\n",
    "            torch.nn.ReLU(),\n",
    "            torch.nn.Linear(64, 64),\n",
    "            torch.nn.ReLU(),\n",
    "        )\n",
    "        self.mu = torch.nn.Sequential(\n",
    "            torch.nn.Linear(64, 1),\n",
    "            torch.nn.Tanh(),\n",
    "        )\n",
    "        self.sigma = torch.nn.Sequential(\n",
    "            torch.nn.Linear(64, 1),\n",
    "            torch.nn.Tanh(),\n",
    "        )\n",
    "\n",
    "    def forward(self, state):\n",
    "        state = self.s(state)\n",
    "        return self.mu(state), self.sigma(state).exp()\n",
    "\n",
    "\n",
    "model_action = ModelAction()\n",
    "\n",
    "model_action(torch.randn(2, 3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.0457],\n",
       "        [-0.0644]], grad_fn=<AddmmBackward0>)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model_value = torch.nn.Sequential(\n",
    "    torch.nn.Linear(4, 64),\n",
    "    torch.nn.ReLU(),\n",
    "    torch.nn.Linear(64, 64),\n",
    "    torch.nn.ReLU(),\n",
    "    torch.nn.Linear(64, 1),\n",
    ")\n",
    "\n",
    "model_value_next = torch.nn.Sequential(\n",
    "    torch.nn.Linear(4, 64),\n",
    "    torch.nn.ReLU(),\n",
    "    torch.nn.Linear(64, 64),\n",
    "    torch.nn.ReLU(),\n",
    "    torch.nn.Linear(64, 1),\n",
    ")\n",
    "\n",
    "model_value_next.load_state_dict(model_value.state_dict())\n",
    "\n",
    "model_value(torch.randn(2, 4))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "d:\\appDir\\python3.10\\lib\\site-packages\\gym\\utils\\passive_env_checker.py:233: DeprecationWarning: `np.bool8` is a deprecated alias for `np.bool_`.  (Deprecated NumPy 1.24)\n",
      "  if not isinstance(terminated, (bool, np.bool8)):\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "-3.0416506867185897"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from IPython import display\n",
    "import random\n",
    "\n",
    "\n",
    "#玩一局游戏并记录数据\n",
    "def play(show=False):\n",
    "    data = []\n",
    "    reward_sum = 0\n",
    "\n",
    "    state = env.reset()\n",
    "    over = False\n",
    "    while not over:\n",
    "        #根据概率采样\n",
    "        mu, sigma = model_action(torch.FloatTensor(state).reshape(1, 3))\n",
    "        action = random.normalvariate(mu=mu.item(), sigma=sigma.item())\n",
    "\n",
    "        next_state, reward, over = env.step(action)\n",
    "\n",
    "        data.append((state, action, reward, next_state, over))\n",
    "        reward_sum += reward\n",
    "\n",
    "        state = next_state\n",
    "\n",
    "        if show:\n",
    "            display.clear_output(wait=True)\n",
    "            env.show()\n",
    "\n",
    "    return data, reward_sum\n",
    "\n",
    "\n",
    "play()[-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\Administrator\\AppData\\Local\\Temp\\ipykernel_10316\\3624659836.py:27: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at ..\\torch\\csrc\\utils\\tensor_new.cpp:248.)\n",
      "  state = torch.FloatTensor([i[0] for i in data]).reshape(-1, 3)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(200,\n",
       " (array([-0.9900937 ,  0.14040813,  0.41601098], dtype=float32),\n",
       "  -1.1293330118035945,\n",
       "  -0.12820265590390023,\n",
       "  array([-0.9915868 ,  0.12944353,  0.22131707], dtype=float32),\n",
       "  False))"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#数据池\n",
    "class Pool:\n",
    "\n",
    "    def __init__(self):\n",
    "        self.pool = []\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.pool)\n",
    "\n",
    "    def __getitem__(self, i):\n",
    "        return self.pool[i]\n",
    "\n",
    "    #更新动作池\n",
    "    def update(self):\n",
    "        #每次更新不少于N条新数据\n",
    "        old_len = len(self.pool)\n",
    "        while len(pool) - old_len < 200:\n",
    "            self.pool.extend(play()[0])\n",
    "\n",
    "        #只保留最新的N条数据\n",
    "        self.pool = self.pool[-2_0000:]\n",
    "\n",
    "    #获取一批数据样本\n",
    "    def sample(self):\n",
    "        data = random.sample(self.pool, 64)\n",
    "\n",
    "        state = torch.FloatTensor([i[0] for i in data]).reshape(-1, 3)\n",
    "        action = torch.FloatTensor([i[1] for i in data]).reshape(-1, 1)\n",
    "        reward = torch.FloatTensor([i[2] for i in data]).reshape(-1, 1)\n",
    "        next_state = torch.FloatTensor([i[3] for i in data]).reshape(-1, 3)\n",
    "        over = torch.LongTensor([i[4] for i in data]).reshape(-1, 1)\n",
    "\n",
    "        return state, action, reward, next_state, over\n",
    "\n",
    "\n",
    "pool = Pool()\n",
    "pool.update()\n",
    "state, action, reward, next_state, over = pool.sample()\n",
    "\n",
    "len(pool), pool[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer_action = torch.optim.Adam(model_action.parameters(), lr=5e-4)\n",
    "optimizer_value = torch.optim.Adam(model_value.parameters(), lr=5e-3)\n",
    "\n",
    "\n",
    "def soft_update(_from, _to):\n",
    "    for _from, _to in zip(_from.parameters(), _to.parameters()):\n",
    "        value = _to.data * 0.995 + _from.data * 0.005\n",
    "        _to.data.copy_(value)\n",
    "\n",
    "\n",
    "def get_action_entropy(state):\n",
    "    mu, sigma = model_action(torch.FloatTensor(state).reshape(-1, 3))\n",
    "    dist = torch.distributions.Normal(mu, sigma)\n",
    "\n",
    "    action = dist.rsample()\n",
    "\n",
    "    return action, sigma\n",
    "\n",
    "\n",
    "def requires_grad(model, value):\n",
    "    for param in model.parameters():\n",
    "        param.requires_grad_(value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.059022821485996246"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def train_value(state, action, reward, next_state, over):\n",
    "    requires_grad(model_value, True)\n",
    "    requires_grad(model_action, False)\n",
    "\n",
    "    #计算target\n",
    "    with torch.no_grad():\n",
    "        #计算动作和熵\n",
    "        next_action, entropy = get_action_entropy(next_state)\n",
    "\n",
    "        #评估next_state的价值\n",
    "        input = torch.cat([next_state, next_action], dim=1)\n",
    "        target = model_value_next(input)\n",
    "\n",
    "    #加权熵,熵越大越好\n",
    "    target = target + 5e-3 * entropy\n",
    "    target = target * 0.99 * (1 - over) + reward\n",
    "\n",
    "    #计算value\n",
    "    value = model_value(torch.cat([state, action], dim=1))\n",
    "\n",
    "    loss = torch.nn.functional.mse_loss(value, target)\n",
    "\n",
    "    loss.backward()\n",
    "    optimizer_value.step()\n",
    "    optimizer_value.zero_grad()\n",
    "\n",
    "    return loss.item()\n",
    "\n",
    "\n",
    "train_value(state, action, reward, next_state, over)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.08403986692428589"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def train_action(state):\n",
    "    requires_grad(model_value, False)\n",
    "    requires_grad(model_action, True)\n",
    "\n",
    "    #计算action和熵\n",
    "    action, entropy = get_action_entropy(state)\n",
    "\n",
    "    #计算value\n",
    "    value = model_value(torch.cat([state, action], dim=1))\n",
    "\n",
    "    #加权熵,熵越大越好\n",
    "    loss = -(value + 5e-3 * entropy).mean()\n",
    "\n",
    "    #使用model_value计算model_action的loss\n",
    "    loss.backward()\n",
    "    optimizer_action.step()\n",
    "    optimizer_action.zero_grad()\n",
    "\n",
    "    return loss.item()\n",
    "\n",
    "\n",
    "train_action(state)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "id": "OHoSU6uI-xIt",
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 400 39.62571140958601\n"
     ]
    }
   ],
   "source": [
    "def train():\n",
    "    model_action.train()\n",
    "    model_value.train()\n",
    "\n",
    "    #训练N次\n",
    "    for epoch in range(10):\n",
    "        #更新N条数据\n",
    "        pool.update()\n",
    "\n",
    "        #每次更新过数据后,学习N次\n",
    "        for i in range(20):\n",
    "            #采样一批数据\n",
    "            state, action, reward, next_state, over = pool.sample()\n",
    "\n",
    "            #训练\n",
    "            train_value(state, action, reward, next_state, over)\n",
    "            train_action(state)\n",
    "            soft_update(model_value, model_value_next)\n",
    "\n",
    "        if epoch % 10 == 0:\n",
    "            test_result = sum([play()[-1] for _ in range(20)]) / 20\n",
    "            print(epoch, len(pool), test_result)\n",
    "\n",
    "\n",
    "train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAR8AAAEXCAYAAACUBEAgAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAikElEQVR4nO3de3RU1b0H8O+ZTGaSkMyEkGTGSAJRrBB5VB7CFG9fRKKkPmpsLaWYKrWFBhYP69KgwOrLUFz1WYSuasXrVfBCG9oGkcYgoZbhFYiGV0orNlmQySAhMyGYmWTmd/+wOdeBgJmQZM/E72etsxZz9p4zvwOc79rn7DlnNBEREBH1M4PqAojo84nhQ0RKMHyISAmGDxEpwfAhIiUYPkSkBMOHiJRg+BCREgwfIlKC4UNESigLn9WrV2P48OGIi4vD5MmTsXfvXlWlEJECSsLnjTfewJIlS7BixQocOHAA48aNQ15eHtxut4pyiEgBTcWNpZMnT8akSZPwm9/8BgAQDAaRmZmJBQsW4NFHH/3M9weDQZw6dQpJSUnQNK2vyyWibhIRtLS0ICMjAwbD5cc2xn6qSef3+1FVVYXi4mJ9ncFgQG5uLpxOZ5fv8fl88Pl8+uuTJ08iJyenz2slop6pr6/H0KFDL9un38Pno48+QiAQgM1mC1lvs9lw7NixLt9TUlKCn/70pxetr6+vh8Vi6ZM6iSh8Xq8XmZmZSEpK+sy+/R4+PVFcXIwlS5borzt30GKxMHyIIlB3Lof0e/ikpqYiJiYGjY2NIesbGxtht9u7fI/ZbIbZbO6P8oion/T7bJfJZMKECRNQUVGhrwsGg6ioqIDD4ejvcohIESWnXUuWLEFhYSEmTpyIm266Cc888wxaW1tx//33qyiHiBRQEj733nsvTp8+jeXLl8PlcuGLX/wi3nrrrYsuQhPRwKXkez5Xyuv1wmq1wuPx8IIzUQQJ59jkvV1EpATDh4iUYPgQkRIMHyJSguFDREowfIhICYYPESnB8CEiJRg+RKQEw4eIlGD4EJESDB8iUoLhQ0RKMHyISAmGDxEpwfAhIiUYPkSkBMOHiJRg+BCREgwfIlKC4UNESjB8iEgJhg8RKcHwISIlGD5EpATDh4iUYPgQkRIMHyJSguFDREowfIhICYYPESnB8CEiJRg+RKQEw4eIlGD4EJESYYfPzp07cfvttyMjIwOapmHz5s0h7SKC5cuX46qrrkJ8fDxyc3Nx/PjxkD5NTU2YNWsWLBYLkpOTMWfOHJw7d+6KdoSIokvY4dPa2opx48Zh9erVXbavWrUKzz33HNauXYs9e/Zg0KBByMvLQ1tbm95n1qxZOHz4MMrLy1FWVoadO3fihz/8Yc/3goiij1wBAFJaWqq/DgaDYrfb5cknn9TXNTc3i9lslvXr14uIyJEjRwSA7Nu3T++zdetW0TRNTp482a3P9Xg8AkA8Hs+VlE9EvSycY7NXr/mcOHECLpcLubm5+jqr1YrJkyfD6XQCAJxOJ5KTkzFx4kS9T25uLgwGA/bs2dPldn0+H7xeb8hCRNGtV8PH5XIBAGw2W8h6m82mt7lcLqSnp4e0G41GpKSk6H0uVFJSAqvVqi+ZmZm9WTYRKRAVs13FxcXweDz6Ul9fr7okIrpCvRo+drsdANDY2BiyvrGxUW+z2+1wu90h7R0dHWhqatL7XMhsNsNisYQsRBTdejV8srOzYbfbUVFRoa/zer3Ys2cPHA4HAMDhcKC5uRlVVVV6n+3btyMYDGLy5Mm9WQ4RRTBjuG84d+4c/vnPf+qvT5w4gerqaqSkpCArKwuLFi3CL37xC1x33XXIzs7GsmXLkJGRgbvuugsAMGrUKNx666148MEHsXbtWrS3t2P+/Pn4zne+g4yMjF7bMSKKcOFOpb3zzjsC4KKlsLBQRD6Zbl+2bJnYbDYxm80ybdo0qa2tDdnGmTNnZObMmZKYmCgWi0Xuv/9+aWlp6XYNnGonikzhHJuaiIjC7OsRr9cLq9UKj8fD6z9EESScYzMqZruIaOBh+BCREgwfIlKC4UNESjB8iEgJhg8RKcHwISIlGD5EpATDh4iUYPgQkRIMHyJSguFDREowfIhICYYPESnB8CEiJRg+RKQEw4eIlGD4EJESDB8iUoLhQ0RKMHyISAmGDxEpwfAhIiUYPkSkBMOHiJRg+BCREgwfIlKC4UNESjB8iEgJhg8RKcHwISIlGD5EpATDh4iUYPgQkRJG1QUQAUAwGEQwGAQAaJqmL52vaeAJa+RTUlKCSZMmISkpCenp6bjrrrtQW1sb0qetrQ1FRUUYMmQIEhMTUVBQgMbGxpA+dXV1yM/PR0JCAtLT0/Hwww+jo6PjyveGotZbb70Fh8OBr371qygoKMCiRYuwZs0a7Ny5E263G+3t7RAR1WVSLwpr5FNZWYmioiJMmjQJHR0dWLp0KaZPn44jR45g0KBBAIDFixdjy5Yt2LhxI6xWK+bPn4+7774bf//73wEAgUAA+fn5sNvt2LVrFxoaGnDfffchNjYWTzzxRO/vIUUFk8kEk8mElpYW1NXVYefOnfD7/TAajbDb7bj55ptxxx134Oabb0ZycjIMBl4xiHpyBdxutwCQyspKERFpbm6W2NhY2bhxo97n6NGjAkCcTqeIiLz55ptiMBjE5XLpfdasWSMWi0V8Pl+3Ptfj8QgA8Xg8V1I+RRC/3y9er1fOnj0rp0+fliNHjsjmzZtlxYoVMm3aNElPT5ekpCSZOHGiPPvss3Ly5EkJBAKqy6YLhHNsXlH4HD9+XABITU2NiIhUVFQIADl79mxIv6ysLHnqqadERGTZsmUybty4kPYPPvhAAMiBAwe6/Jy2tjbxeDz6Ul9fz/D5HAgGgxIIBKS1tVXee+89+eUvfyljxoyRhIQEGT9+vLz66qty7tw5CQaDqkul/wgnfHo8dg0Gg1i0aBGmTp2K0aNHAwBcLhdMJhOSk5ND+tpsNrhcLr2PzWa7qL2zrSslJSWwWq36kpmZ2dOyKYpomgaDwYCEhASMGTMGjz76KMrLy1FSUoLW1lb86Ec/wpw5c3D8+HFeD4pCPQ6foqIiHDp0CBs2bOjNerpUXFwMj8ejL/X19X3+mRRZOoMoPT0d8+fPx5YtWzB79mxs2bIFBQUFePvttxEIBFSXSWHoUfjMnz8fZWVleOeddzB06FB9vd1uh9/vR3Nzc0j/xsZG2O12vc+Fs1+drzv7XMhsNsNisYQs9PnUGULXXnstnnnmGaxduxYtLS34/ve/jzfeeIOzYlEkrPAREcyfPx+lpaXYvn07srOzQ9onTJiA2NhYVFRU6Otqa2tRV1cHh8MBAHA4HKipqYHb7db7lJeXw2KxICcn50r2hT5n4uLiMHPmTKxfvx52ux2LFy/G66+/zhFQtAjnYtK8efPEarXKjh07pKGhQV/Onz+v95k7d65kZWXJ9u3bZf/+/eJwOMThcOjtHR0dMnr0aJk+fbpUV1fLW2+9JWlpaVJcXNztOjjbRZ8WDAbl0KFD4nA4xGazyaZNmzgTpkifzXYB6HJ5+eWX9T4ff/yx/PjHP5bBgwdLQkKCfPOb35SGhoaQ7Xz44Ydy2223SXx8vKSmpspDDz0k7e3t3a6D4UMXCgaD8t5778nYsWNl+PDh4nQ6OQumQDjHpiYSfSfIXq8XVqsVHo+H139IJyLYsWMHZs2ahezsbGzatAl2u523Z/SjcI5Nfk2UBgxN0/DlL38ZS5cuxcGDB/H000/ztp0IxvChASUmJgbf//73MWPGDPz+97/Hrl27OPsVoRg+NOAkJiZi6dKlMJvNeOaZZ9Da2qq6JOoCw4cGpHHjxuF73/setm/fjp07d3L0E4EYPjQgGQwG3H///bBarVi3bh18Pp/qkugCDB8akDRNw4gRIzBjxgxUVlaitraWo58Iw/ChASsmJgb33HMP/H4/tm3bxvCJMAwf6hafz4fz589H1QGsaRomTJiA7OxsbNu2jadeEYbhQ59JRFBZWYlt27apLiVsSUlJuPnmm3Hs2DE+DSHCMHzoM3V0dOAPf/gDXnvtNbS3t/fadiUYREdLC/xnzqC9uRlBn6/XR1YxMTGYMmUKWlpacOTIkagauQ10/PUK+kxutxs7duyA1+vFP//5zyt++oCIoL2pCae3boVn7174P/oIBrMZCSNGIH3GDCSNHQstJqZXatc0DaNHj0ZMTAwOHz6MO++8s1e2S1eO4UOXJSJwOp2oq6uD3+/Hli1bMGrUqB7fLyUi8J06hQ+ffRattbXAf0YigXPn4DlzBudqapBx331Iy8vrtQCy2+1ISUnRZ7x4r1dk4GkXXVZ7ezvKysrg9/sBAKWlpfB4PD3eXuD8edT99rdoPXYMEgzirM+H/R99hONeL4IiCJw/j5P//d/wHDjQa6dI8fHxsNvtqK+v52lXBGH40GU1NjZi586d+uv33nsPBw8e7PFB7Nm3Dy3vvw8RQV1rKxbu3Yui3bvxo127sOHECQREEDx/Ho2lpQi2tfXKPhiNRgwZMgRer7dXtke9g+FDlyQi+Nvf/oaGhgZ9XVtbGzZt2tSjpwVKMIiW998HgkEIgF/V1OBIczMCIvC2t+M3R4/i0NmzAIDzH3yAwMcf98p+xMTEYNCgQRz1RBhe86FLam9vx5YtW5CWlgb55MFzyMrKQkVFBRoaGsL+FZH25mZ49u3TX3svmDnzB4Pw9dEjUPkjg5GH/yJ0SadOnYLf78err76K8ePHIzU1FevWrUNeXh52794d9khCAgEE/3PtSAPwNbsdxk9d/P2CxYJhiYm9uQuffK4Iv2AYgTjyoUuKiYnBU089BbvdjpdeegltbW1ITk5GSUkJ/v3vfyMYDCImnBmpT4WVpmkoHDECSbGxeLuhAVfFx+PBL3wB6XFxvb4fHR0d8Hg8iI2N7fVtU88xfOiSPv2zSOnp6WhpaYHH40FaWhpGjhwZ/gaDwZCXRoMB3xo+HPcMH47O8U9fTIP7/X643W5kZWX1+rap53jaRZekaZq+DBs2DB6PB6dPnw5ZHw4RCRn9dH6GoYfb6y6v1wu3241rrrmG3/GJIAwf6pZRo0aho6MDR48e7flGLhj59AcRwQcffIDW1lbk5OQwfCIIw4e65ZprrkFqaiqcTieCPQ0RRVPdBw8ehNFoxA033KDk86lrDB/qFrvdjpycHFRVVaGpqalH21DxPRufz4d3330XNpsNI0eO5MgngjB8qFvi4+Px9a9/HcePH+/5qZeC065Tp05h//79mDp1KpKTk/v98+nSGD7ULZqmIS8vD3FxcSgtLe3ZN5y7uODcl0QEb7/9NjweD/Lz88P7WgD1OYYPdduoUaMwZcoUlJWV4eTJk+FvoJ+D59y5c1i/fj2GDRuGm2++madcEYbhQ91mNptRWFiIU6dOYePGjeFfeO7n067t27fjwIED+Na3voWUlJR+/Wz6bAwf6jZN03DLLbfgxhtvxEsvvYS6urqwLiL35wVnr9eL559/HqmpqfjOd77DU64IxPChsFitVixYsAAnT57ECy+8EN5jVcMY+RhiY3t8mhQMBrFhwwY4nU488MADuOaaa3q0HepbDB8KW35+PqZPn45169ahsrKy2yMaEUF3xz5xQ4fC0MP7vI4ePYpVq1YhJycHDzzwAEc9EYrhQ2HRNA0JCQl47LHHMGjQIDz22GPdf0JgGLNdWkwM0IPbN86ePYvHHnsMTU1NWLFiBex2e1jboP7D8KGwaZqGMWPG4PHHH8exY8fwyCOPwOv1fnYAhXHaFW74iAja2trwq1/9CuXl5Zg/fz7y8vI4wxXBeFc79UhMTAy++93vora2FqtXr0Z6ejp+8YtfIDEx8ZIHfFgXnA0GdDc2RAR+vx/PP/881qxZg9tvvx0PPfQQjEb+945k/NehHouPj0dxcTGamprw0ksvAQBWrFiBwYMHdx1AYYSPZjB0a+TTOeJ57rnnsHLlSkyZMgVPPvkkrFYrRz0RjuFDVyQ5ORmrVq2CiOCll17C6dOn8cQTT2DYsGEXHfzSy6ddIoKmpiY88cQT+N3vfgeHw4EXXngBQ4cOZfBEAV7zoSuiaRoGDx6MX//615g7dy7Kysrw7W9/G2+//fbF0/Dh3F7xGc9cDgQCOHjwIGbPno01a9ZgxowZePHFF/nMnigSVvisWbMGY8eOhcVigcVigcPhwNatW/X2trY2FBUVYciQIUhMTERBQQEaGxtDtlFXV4f8/HwkJCQgPT0dDz/8MDo6Onpnb0gJTdNgtVrx85//HKtWrUJDQwNmzpyJpUuX4sMPP9QfPh/WaVcXI5/O7bjdbjz11FO466674HQ68ZOf/ARr167liCfKhBU+Q4cOxcqVK1FVVYX9+/fj61//Ou68804cPnwYALB48WL85S9/wcaNG1FZWYlTp07h7rvv1t8fCASQn58Pv9+PXbt24ZVXXsG6deuwfPny3t0r6neapiEuLg4PPvggNm3ahMmTJ2P16tXIy8tDSUkJ/vWvf6EjjC8kagaDHiQigkAgAJfLhRdffBH5+flYsWIFrr76arz22mtYtmwZr/FEIU2u8DvvKSkpePLJJ3HPPfcgLS0Nr7/+Ou655x4AwLFjxzBq1Cg4nU5MmTIFW7duxTe+8Q2cOnUKNpsNALB27Vo88sgjOH36NEwmU5ef4fP5Qn59wOv1IjMzEx6PBxaL5UrKpz4gIvB6vdi8eTNWr16Nw4cPIy0tDVNHjsQXz55FTlISks1mxH7qEaoXvn/w176GjLlz0ezxoLq6Gtu2bcO2bdtQV1eH4cOH4wc/+AFmzZqFtLQ0hk4E8Xq9sFqt3To2e3zBORAIYOPGjWhtbYXD4UBVVRXa29uRm5ur9xk5ciSysrL08HE6nRgzZowePACQl5eHefPm4fDhw7jxxhu7/KySkhL89Kc/7Wmp1M86T8Puu+8+zJgxA9u2bcP69evx1927sdnjQaLRiKzERGQnJuKqhAQMNpkQbzRCA+ALBNDc3o7m5macKi/HsWPHcObMGRiNRowZMwYLFizAnXfeiauuuoq/xRXlwg6fmpoaOBwOtLW1ITExEaWlpcjJyUF1dTVMJtNFD2yy2WxwuVwAAJfLFRI8ne2dbZdSXFyMJUuW6K87Rz4U2TRNQ1paGmbNmoW7774b1WVl2PzEEzjgduODlhZ80NICXyCA4Kduu9AAGDQNcQkJSLHbccMNN2Dq1Kn46le/ihtuuAFJSUkc6QwQYYfP9ddfj+rqang8HmzatAmFhYWorKzsi9p0ZrMZZrO5Tz+D+k7nLRkj7Xbcm52NgqFDcb6jA6d9PjT5fDjX3g7/f6bhYw0GDDIace0tt2D8ggWwWq0wGo19+usWpEbY4WMymTBixAgAwIQJE7Bv3z48++yzuPfee+H3+9Hc3Bwy+mlsbNTvr7Hb7di7d2/I9jpnw3gPzsD3cX09EAjAaDDAYjLBYjLh2qSkLvumZWQgNTWVgTOAXfFJczAYhM/nw4QJExAbG4uKigq9rba2FnV1dXA4HAAAh8OBmpoauN1uvU95eTksFgtycnKutBSKdGHObTB4BrawRj7FxcW47bbbkJWVhZaWFrz++uvYsWMHtm3bBqvVijlz5mDJkiVISUmBxWLBggUL4HA4MGXKFADA9OnTkZOTg9mzZ2PVqlVwuVx4/PHHUVRUxNMqos+ZsMLH7XbjvvvuQ0NDA6xWK8aOHYtt27bhlltuAQA8/fTTMBgMKCgogM/nQ15eHl544QX9/TExMSgrK8O8efPgcDgwaNAgFBYW4mc/+1nv7hURRbwr/p6PCuF8l4Aix8n/+R+4/vd/u9U3bcYMZM2d28cVUW8L59jkFyWISAmGDxEpwfAhIiUYPkSkBMOHiJRg+BCREgwfIlKC4UMRyRAfr7oE6mMMH4o8BgPihw1TXQX1MYYPRSSNP3E84DF8KCIxfAY+hg9FHk1j+HwOMHwoIjF8Bj6GD0Umhs+Ax/ChiKOBI5/PA4YPRSSGz8DH8KHIwwvOnwsMH4pIDJ+Bj+FDEYnhM/AxfCgiMXwGPoYPRSaGz4DH8KGIxB8MHPgYPtRv4jIyujWiMaWmIoaP1BjwGD7Ub6yTJmHQF75w+U4xMUi97TYYBw/un6JIGYYP9ZuYxERkPvgg4jIzL9EhBqm5uUi79Vaedn0OMHyo32iahoRrr8U1jzyCwf/1XzBarYDBAC02Fuarr8bVs2dj6AMP8JTrcyKs32onulKapiEuMxPDFy+Gv7ERHR4PNKMRprQ0GJOTVZdH/YjhQ/1O0zRoRiPirr4auPpq1eWQIjztIiIlGD5EpATDh4iUYPgQkRIMHyJSguFDREowfIhIiSsKn5UrV0LTNCxatEhf19bWhqKiIgwZMgSJiYkoKChAY2NjyPvq6uqQn5+PhIQEpKen4+GHH0ZHR8eVlEJEUabH4bNv3z789re/xdixY0PWL168GH/5y1+wceNGVFZW4tSpU7j77rv19kAggPz8fPj9fuzatQuvvPIK1q1bh+XLl/d8L4go+kgPtLS0yHXXXSfl5eXyla98RRYuXCgiIs3NzRIbGysbN27U+x49elQAiNPpFBGRN998UwwGg7hcLr3PmjVrxGKxiM/n6/Lz2traxOPx6Et9fb0AEI/H05PyiaiPeDyebh+bPRr5FBUVIT8/H7m5uSHrq6qq0N7eHrJ+5MiRyMrKgtPpBAA4nU6MGTMGNptN75OXlwev14vDhw93+XklJSWwWq36knmpu6KJKGqEHT4bNmzAgQMHUFJSclGby+WCyWRC8gU3CNpsNrhcLr3Pp4Ons72zrSvFxcXweDz6Ul9fH27ZRBRhwrqxtL6+HgsXLkR5eTni4uL6qqaLmM1mmM3mfvs8Iup7YY18qqqq4Ha7MX78eBiNRhiNRlRWVuK5556D0WiEzWaD3+9Hc3NzyPsaGxtht9sBAHa7/aLZr87XnX2IaOALK3ymTZuGmpoaVFdX68vEiRMxa9Ys/c+xsbGoqKjQ31NbW4u6ujo4HA4AgMPhQE1NDdxut96nvLwcFosFOTk5vbRbRBTpwjrtSkpKwujRo0PWDRo0CEOGDNHXz5kzB0uWLEFKSgosFgsWLFgAh8OBKVOmAACmT5+OnJwczJ49G6tWrYLL5cLjjz+OoqIinloRfY70+sPEnn76aRgMBhQUFMDn8yEvLw8vvPCC3h4TE4OysjLMmzcPDocDgwYNQmFhIX72s5/1dilEFME0ERHVRYTL6/XCarXC4/HAYrGoLoeI/iOcY5P3dhGREgwfIlKC4UNESjB8iEgJhg8RKcHwISIlGD5EpATDh4iUYPgQkRIMHyJSguFDREowfIhICYYPESnB8CEiJRg+RKQEw4eIlGD4EJESDB8iUoLhQ0RKMHyISAmGDxEpwfAhIiUYPkSkBMOHiJRg+BCREgwfIlKC4UNESjB8iEgJhg8RKcHwISIlGD5EpATDh4iUYPgQkRIMHyJSguFDREowfIhICYYPESlhVF1AT4gIAMDr9SquhIg+rfOY7DxGLycqw+fMmTMAgMzMTMWVEFFXWlpaYLVaL9snKsMnJSUFAFBXV/eZOxhpvF4vMjMzUV9fD4vForqcbmPd/Sta6xYRtLS0ICMj4zP7RmX4GAyfXKqyWq1R9Q/zaRaLJSprZ939Kxrr7u6AgBeciUgJhg8RKRGV4WM2m7FixQqYzWbVpYQtWmtn3f0rWusOhybdmRMjIuplUTnyIaLox/AhIiUYPkSkBMOHiJRg+BCRElEZPqtXr8bw4cMRFxeHyZMnY+/evUrr2blzJ26//XZkZGRA0zRs3rw5pF1EsHz5clx11VWIj49Hbm4ujh8/HtKnqakJs2bNgsViQXJyMubMmYNz5871ad0lJSWYNGkSkpKSkJ6ejrvuugu1tbUhfdra2lBUVIQhQ4YgMTERBQUFaGxsDOlTV1eH/Px8JCQkID09HQ8//DA6Ojr6rO41a9Zg7Nix+rd/HQ4Htm7dGtE1d2XlypXQNA2LFi2Kutp7hUSZDRs2iMlkkt///vdy+PBhefDBByU5OVkaGxuV1fTmm2/KY489Jn/84x8FgJSWloa0r1y5UqxWq2zevFnee+89ueOOOyQ7O1s+/vhjvc+tt94q48aNk927d8vf/vY3GTFihMycObNP687Ly5OXX35ZDh06JNXV1TJjxgzJysqSc+fO6X3mzp0rmZmZUlFRIfv375cpU6bIl770Jb29o6NDRo8eLbm5uXLw4EF58803JTU1VYqLi/us7j//+c+yZcsW+cc//iG1tbWydOlSiY2NlUOHDkVszRfau3evDB8+XMaOHSsLFy7U10dD7b0l6sLnpptukqKiIv11IBCQjIwMKSkpUVjV/7swfILBoNjtdnnyySf1dc3NzWI2m2X9+vUiInLkyBEBIPv27dP7bN26VTRNk5MnT/Zb7W63WwBIZWWlXmdsbKxs3LhR73P06FEBIE6nU0Q+CV6DwSAul0vvs2bNGrFYLOLz+fqt9sGDB8uLL74YFTW3tLTIddddJ+Xl5fKVr3xFD59oqL03RdVpl9/vR1VVFXJzc/V1BoMBubm5cDqdCiu7tBMnTsDlcoXUbLVaMXnyZL1mp9OJ5ORkTJw4Ue+Tm5sLg8GAPXv29FutHo8HwP8/NaCqqgrt7e0htY8cORJZWVkhtY8ZMwY2m03vk5eXB6/Xi8OHD/d5zYFAABs2bEBrayscDkdU1FxUVIT8/PyQGoHo+PvuTVF1V/tHH32EQCAQ8hcPADabDceOHVNU1eW5XC4A6LLmzjaXy4X09PSQdqPRiJSUFL1PXwsGg1i0aBGmTp2K0aNH63WZTCYkJydftvau9q2zra/U1NTA4XCgra0NiYmJKC0tRU5ODqqrqyO2ZgDYsGEDDhw4gH379l3UFsl/330hqsKH+k5RUREOHTqEd999V3Up3XL99dejuroaHo8HmzZtQmFhISorK1WXdVn19fVYuHAhysvLERcXp7oc5aLqtCs1NRUxMTEXXf1vbGyE3W5XVNXlddZ1uZrtdjvcbndIe0dHB5qamvplv+bPn4+ysjK88847GDp0qL7ebrfD7/ejubn5srV3tW+dbX3FZDJhxIgRmDBhAkpKSjBu3Dg8++yzEV1zVVUV3G43xo8fD6PRCKPRiMrKSjz33HMwGo2w2WwRW3tfiKrwMZlMmDBhAioqKvR1wWAQFRUVcDgcCiu7tOzsbNjt9pCavV4v9uzZo9fscDjQ3NyMqqoqvc/27dsRDAYxefLkPqtNRDB//nyUlpZi+/btyM7ODmmfMGECYmNjQ2qvra1FXV1dSO01NTUh4VleXg6LxYKcnJw+q/1CwWAQPp8vomueNm0aampqUF1drS8TJ07ErFmz9D9Hau19QvUV73Bt2LBBzGazrFu3To4cOSI//OEPJTk5OeTqf39raWmRgwcPysGDBwWAPPXUU3Lw4EH597//LSKfTLUnJyfLn/70J3n//fflzjvv7HKq/cYbb5Q9e/bIu+++K9ddd12fT7XPmzdPrFar7NixQxoaGvTl/Pnzep+5c+dKVlaWbN++Xfbv3y8Oh0McDofe3jn1O336dKmurpa33npL0tLS+nTq99FHH5XKyko5ceKEvP/++/Loo4+Kpmny17/+NWJrvpRPz3ZFW+1XKurCR0Tk+eefl6ysLDGZTHLTTTfJ7t27ldbzzjvvCICLlsLCQhH5ZLp92bJlYrPZxGw2y7Rp06S2tjZkG2fOnJGZM2dKYmKiWCwWuf/++6WlpaVP6+6qZgDy8ssv630+/vhj+fGPfyyDBw+WhIQE+eY3vykNDQ0h2/nwww/ltttuk/j4eElNTZWHHnpI2tvb+6zuBx54QIYNGyYmk0nS0tJk2rRpevBEas2XcmH4RFPtV4rP8yEiJaLqmg8RDRwMHyJSguFDREowfIhICYYPESnB8CEiJRg+RKQEw4eIlGD4EJESDB8iUoLhQ0RK/B+bYXbuJ/K/uQAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 300x300 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "-23.219147946687816"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "play(True)[-1]"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "第7章-DQN算法.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
